백준 1761번

정점들의 거리

Posted by ChoiCube84 on July 11, 2024 · 18 mins read

백준 1761번

오늘 풀어본 문제는 백준의 1761번 문제1이다. 문제 풀이에 사용한 언어는 C++ 이다.

solved.ac 기준 CLASS

CLASS 6

문제 정보

이 문제의 내용과 조건은 다음과 같다.

문제

$N$ $(2 \le N \le 40,000)$ 개의 정점으로 이루어진 트리가 주어지고 $M$ $(1 \le M \le 10,000)$ 개의 두 노드 쌍을 입력받을 때 두 노드 사이의 거리를 출력하라.

입력

첫째 줄에 노드의 개수 $N$ 이 입력되고 다음 $N-1$ 개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 $M$ 이 주어지고, 다음 $M$ 개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩 입력된다. 두 점 사이의 거리는 10,000보다 작거나 같은 자연수이다.

정점은 $1$ 번부터 $N$ 번까지 번호가 매겨져 있다.

출력

$M$ 개의 줄에 차례대로 입력받은 두 노드 사이의 거리를 출력한다.

풀이과정

1번째 시도

두 정점의 거리를 구하기 위해서는, 두 정점 사이의 경로를 알아야 하는데 이 과정에서 ‘최소 공통 조상’을 반드시 거칠 수 밖에 없다는 걸 생각해냈고, 기존에 풀었던 백준 11438번 문제에서 짜뒀던 Sparse Table을 이용한 LCA를 찾는 코드를 활용하기로 했다.

이 문제에서는 단순히 LCA를 구하는 것이 아닌, 두 정점간의 거리를 구해야 하기 때문에, 각 노드에서 LCA까지의 거리를 알아내어 이를 더하는 방식으로 해결할 수 있을 것이라고 생각하였다. 그래서 LCA를 추적하는 함수에서 LCA까지의 거리도 함께 추적하도록 하였다.

코드는 다음과 같이 작성하였다.

#include <bits/extc++.h>

using namespace std;

using ll = long long int;
using ull = unsigned long long int;

using pll = pair<ll, ll>;

vector<vector<pll>> tree(40001);

vector<int> depth(40001, 0);
vector<vector<pll>> parent(40001, vector<pll>(17));
vector<bool> visited(40001, false);

void setImmediateParent(int currentNode, int currentDepth);
void setParentSparseTable(int N);
int findDistance(int nodeA, int nodeB);

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int N, M, nodeA, nodeB, weight;

    cin >> N;

    for (int i=0; i<N-1; i++) {
        cin >> nodeA >> nodeB >> weight;

        tree[nodeA].emplace_back(nodeB, weight);
        tree[nodeB].emplace_back(nodeA, weight);
    }

    setImmediateParent(1, 0);
    setParentSparseTable(N);

    cin >> M;

    for (int i=0; i<M; i++) {
        cin >> nodeA >> nodeB;
        cout << findDistance(nodeA, nodeB) << '\n';
    }

    return 0;
}

void setImmediateParent(int curr, int currentDepth) {
    visited[curr] = true;
    depth[curr] = currentDepth;

    for (auto [next, dist] : tree[curr]) {
        if (visited[next] == false) {
            parent[next][0].first = curr;
            parent[next][0].second = dist;

            setImmediateParent(next, currentDepth + 1);
        }
    }
}

void setParentSparseTable(int N) {
    int maximumHeight = 0;

    for (int tempN=N; tempN>1; maximumHeight++) {
        tempN >>= 1;
    }

    for (int i=1; i<=maximumHeight; i++) {
        for (int j=1; j<=N; j++) {
            parent[j][i].first = parent[parent[j][i-1].first][i-1].first;
            parent[j][i].second = parent[j][i-1].second + parent[parent[j][i-1].first][i-1].second;
        }
    }
}

int findDistance(int nodeA, int nodeB) {
    if (depth[nodeA] < depth[nodeB]) {
        swap(nodeA, nodeB);
    }

    int distFromNodeA = 0;
    int distFromNodeB = 0;

    int difference = depth[nodeA] - depth[nodeB];
    int power = 0;

    while (difference > 0) {
        if (difference % 2 == 1) {
            distFromNodeA += parent[nodeA][power].second;
            nodeA = parent[nodeA][power].first;
        }
        difference >>= 1;
        power++;
    }

    if (nodeA == nodeB) {
        return distFromNodeA;
    }
    else {
        for (int i=16; i>=0; i--) {
            if (parent[nodeA][i].first != 0 && parent[nodeA][i].first != parent[nodeB][i].first) {
                distFromNodeA += parent[nodeA][i].second;
                distFromNodeB += parent[nodeB][i].second;
                
                nodeA = parent[nodeA][i].first;
                nodeB = parent[nodeB][i].first;
            }
            else {
                for (int j=0; j<=i; j++) {
                    if (parent[nodeA][j].first == parent[nodeB][j].first) {
                        distFromNodeA += parent[nodeA][j].second;
                        distFromNodeB += parent[nodeB][j].second;
                        break;
                    }
                }
                break;
            }
        }

        return distFromNodeA + distFromNodeB;
    }
}

실행 결과 ‘틀렸습니다’ 가 떴다.

2번째 시도

문제가 되었던 부분은, 마지막에 최소 공통 조상을 추적하는 과정에서 ‘최소’가 아닌 공통 조상까지의 거리를 구하게 될 수 있다는 점이었다. 이는 두 노드의 조상이 같을 때, 최소 공통 조상을 찾는 코드를 잘못 작성했기 때문이었다.

기존의 코드는 $2^i$ 차 부모가 같으면 $j$ 를 $0$ 부터 $i$ 까지 순회하며 $2^j$ 차 부모가 같아지는 최소의 $j$ 로 LCA를 찾았는데, 각 $j$ 는 $2^j$ 차 부모 밖에 표현하지 못하기 때문에 ($2$의 거듭제곱이 아닌 수) 차 부모는 찾아낼 수 없던 것이다. 그래서 $2^j$ 차 부모가 같아질 때, $2^{j-1}$ 차 부모까지의 거리를 먼저 계산하고, 그 부모들의 LCA를 다시 추적해나가는 방식으로 해결을 시도해보았다.

코드는 다음과 같이 수정하였다.

#include <bits/extc++.h>

using namespace std;

using ll = long long int;
using ull = unsigned long long int;

using pll = pair<ll, ll>;

vector<vector<pll>> tree(40001);

vector<int> depth(40001, 0);
vector<vector<pll>> parent(40001, vector<pll>(17));
vector<bool> visited(40001, false);

void setImmediateParent(int currentNode, int currentDepth);
void setParentSparseTable(int N);
int findDistance(int nodeA, int nodeB);

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int N, M, nodeA, nodeB, weight;

    cin >> N;

    for (int i=0; i<N-1; i++) {
        cin >> nodeA >> nodeB >> weight;

        tree[nodeA].emplace_back(nodeB, weight);
        tree[nodeB].emplace_back(nodeA, weight);
    }

    setImmediateParent(1, 0);
    setParentSparseTable(N);

    cin >> M;

    for (int i=0; i<M; i++) {
        cin >> nodeA >> nodeB;
        cout << findDistance(nodeA, nodeB) << '\n';
    }

    return 0;
}

void setImmediateParent(int curr, int currentDepth) {
    visited[curr] = true;
    depth[curr] = currentDepth;

    for (auto [next, dist] : tree[curr]) {
        if (visited[next] == false) {
            parent[next][0].first = curr;
            parent[next][0].second = dist;

            setImmediateParent(next, currentDepth + 1);
        }
    }
}

void setParentSparseTable(int N) {
    int maximumHeight = 0;

    for (int tempN=N; tempN>1; maximumHeight++) {
        tempN >>= 1;
    }

    for (int i=1; i<=maximumHeight; i++) {
        for (int j=1; j<=N; j++) {
            parent[j][i].first = parent[parent[j][i-1].first][i-1].first;
            parent[j][i].second = parent[j][i-1].second + parent[parent[j][i-1].first][i-1].second;
        }
    }
}

int findDistance(int nodeA, int nodeB) {
    if (depth[nodeA] < depth[nodeB]) {
        swap(nodeA, nodeB);
    }

    int distFromNodeA = 0;
    int distFromNodeB = 0;

    int difference = depth[nodeA] - depth[nodeB];
    int power = 0;

    while (difference > 0) {
        if (difference % 2 == 1) {
            distFromNodeA += parent[nodeA][power].second;
            nodeA = parent[nodeA][power].first;
        }
        difference >>= 1;
        power++;
    }

    if (nodeA == nodeB) {
        return distFromNodeA;
    }
    else {
        for (int i=16; i>=0; i--) {
            if (parent[nodeA][i].first != 0) {
                if (parent[nodeA][i].first != parent[nodeB][i].first) {
                    distFromNodeA += parent[nodeA][i].second;
                    distFromNodeB += parent[nodeB][i].second;

                    nodeA = parent[nodeA][i].first;
                    nodeB = parent[nodeB][i].first;
                }
                else {
                    for (int j=0; j<=i; j++) {
                        if (parent[nodeA][j].first == parent[nodeB][j].first) {
                            if (j <= 1) {
                                distFromNodeA += parent[nodeA][j].second;
                                distFromNodeB += parent[nodeB][j].second;
                                break;
                            }
                            else {
                                distFromNodeA += parent[nodeA][j-1].second;
                                distFromNodeB += parent[nodeB][j-1].second;

                                nodeA = parent[nodeA][j-1].first;
                                nodeB = parent[nodeB][j-1].first;

                                j = -1;
                            }
                        }
                    }
                    break;
                }
            }
        }

        return distFromNodeA + distFromNodeB;
    }
}

그러자 모든 테스트 케이스를 통과하고 ‘맞았습니다’가 나오는 것을 확인할 수 있었다.

마무리

LCA 알고리즘은 상당히 오랜만에 쓰는 알고리즘이라 제대로 활용할 수 있을지 걱정했지만, 생각보다 간단하게 해결된 것 같다. CLASS 6 문제들을 풀다 보면 종종 보게될 것 같은데, 그 문제들도 쉽게 풀렸으면 좋겠다.

오늘의 PS는 여기까지!


1: https://www.acmicpc.net/problem/1761